from flax.linen import *
import scalevi.nn.initializers as initializers 
import jax.numpy as np
from typing import Any, Callable, Sequence, Optional, Tuple, Union

# Loose type aliases used in the initializer signatures of the modules below.
PRNGKey = Any
# Variable-length tuple of ints: the modules in this file request parameter
# shapes of different ranks (a 3-D kernel and a 1-D bias), which the
# single-element form ``Tuple[int]`` does not describe.
Shape = Tuple[int, ...]
Dtype = Any  # this could be a real type?
Array = Any


class Bilinear(Module):
    """Bilinear layer combining two inputs into ``features`` output channels.

    For every output channel ``c`` the layer contracts the last axis of ``x``
    and the last axis of ``y`` against a learned kernel:
    ``out[..., c] = sum_{i,k} x[..., i] * kernel[c, i, k] * y[..., k]``.
    A per-channel bias is added when ``use_bias`` is set.

    Attributes:
        features: number of output channels.
        use_bias: whether to add a learned bias of shape ``(features,)``.
        kernel_init: initializer for the kernel parameter.
        bias_init: initializer for the bias parameter.
    """
    features: int
    use_bias: bool = False
    kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.normal(0.001)
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros

    @compact
    def __call__(self, x: Array, y: Array) -> Array:
        """Apply the bilinear map to ``x`` and ``y``.

        Args:
            x: array whose last axis is contracted with the kernel's middle axis.
            y: array whose last axis is contracted with the kernel's last axis.

        Returns:
            Array whose last axis has size ``features``.
        """
        # The kernel is sized from the trailing dimensions of both inputs.
        kernel_shape = (self.features, x.shape[-1], y.shape[-1])
        weights = np.asarray(self.param('kernel', self.kernel_init, kernel_shape))
        out = np.einsum(
            "...i, cik, ...k -> ...c",
            x, weights, y,
            optimize='greedy')
        if not self.use_bias:
            return out
        offset = np.asarray(self.param("bias", self.bias_init, (self.features,)))
        return out + offset

class Block(Module):
    """A simple encoder block: one Dense layer followed by a leaky ReLU.

    Attributes:
        feature: output width of the Dense layer.
    """
    feature: int

    @compact
    def __call__(self, x: Array) -> Array:
        """Return ``leaky_relu(Dense(feature)(x))``."""
        # The inner Dense layer is named after this module, falling back to a
        # fixed prefix when the module itself has no name.
        prefix = self.name if self.name is not None else "Dense_Block"
        hidden = Dense(self.feature, name=f"{prefix} Dense")(x)
        return leaky_relu(hidden)

class FCN(Module):
    """A simple multi-layer fully connected network.

    Every entry of ``features`` except the last produces a ``Block`` (Dense
    plus leaky ReLU); the last entry produces a plain ``Dense`` layer with no
    activation.

    Attributes:
        features: output width of each successive layer.
    """
    features: Sequence[int]

    @compact
    def __call__(self, x: Array) -> Array:
        """Pass ``x`` through the layers in order and return the result."""
        prefix = self.name if self.name is not None else 'FCN_Block'
        final_idx = len(self.features) - 1
        for idx, width in enumerate(self.features):
            layer_name = f"{prefix}_{idx}"
            if idx == final_idx:
                # Output layer: linear only.
                x = Dense(width, name=layer_name)(x)
            else:
                x = Block(feature=width, name=layer_name)(x)
        return x


class Stats(Module):
    """Summary statistics of an array along axis 0.

    Given a 2D array X, returns the concatenation of the axis-0 means of
    X and X**2. When ``keep_sum_stats`` is True the axis-0 sums are
    included as well, in the order
    ``[sum(X), mean(X), sum(X**2), mean(X**2)]``.
    """

    keep_sum_stats: bool = False

    @compact
    def __call__(self, x: Array) -> Array:
        """Return the concatenated statistics of ``x``."""
        squared = x**2
        if not self.keep_sum_stats:
            return np.concatenate([
                np.mean(x, 0),
                np.mean(squared, 0),
            ])
        return np.concatenate([
            np.sum(x, 0),
            np.mean(x, 0),
            np.sum(squared, 0),
            np.mean(squared, 0),
        ])

def _apply_mask(x, mask):
    """Multiply ``x`` by ``mask`` along the leading axis of ``x``.

    For a 1-D ``mask`` this scales ``x[i]`` by ``mask[i]``.
    """
    # Transposing moves the leading axis last so that ``mask`` broadcasts
    # against it; the second transpose restores the original layout.
    scaled_t = mask * x.T
    return scaled_t.T

class MaskedStats(Stats):
    """A masked version of the Stats Module.

    Rows of the input are multiplied by the mask before any statistic is
    taken, and the means are formed by dividing the masked sums by the sum
    of the mask rather than by the number of rows.
    """

    @compact
    def __call__(self, x: Array, mask: Array) -> Array:
        """Return the masked statistics of ``x``.

        Args:
            x (Array): 2-D array to be masked
            mask (Array): A boolean mask for the 0 axis (for rows.)

        Returns:
            Concatenation of the masked mean of x and of x**2; with
            ``keep_sum_stats`` the masked sums are included, in the order
            ``[sum, mean, sum of squares, mean of squares]``.

        Note:
            The division by ``sum(mask)`` is not guarded, so a mask that
            sums to zero divides by zero.
        """
        kept = _apply_mask(x, mask)
        count = np.sum(mask)
        total = np.sum(kept, 0)
        # Squares are taken after masking, so masked-out rows contribute zero.
        total_sq = np.sum(kept**2, 0)
        if not self.keep_sum_stats:
            return np.concatenate([
                total / count,
                total_sq / count,
            ])
        return np.concatenate([
            total,
            total / count,
            total_sq,
            total_sq / count,
        ])